
# Workflow configuration file; the rules below read its values through `config`.
configfile: 'config/v1.yaml'

# Container image declared at workflow level (not per rule).
singularity: "docker://alzhang/spectrum-master-rstudio:v1.4"

def get_celltype_outputs(wildcards):
    """Enumerate the cell-type evaluation tables requested from the workflow.

    Passed as an input function to ``rule all``. The ``wildcards`` argument is
    accepted to satisfy the input-function signature but is not read.

    Returns:
        Whatever ``expand`` produces for the
        ``<outdir>/evaluate/celltypes/.../*.tsv`` pattern, given the simulation
        variables (``de_prob``, ``down_prob``, ``seed``) and the assignment
        settings (``max_genes``, ``methods``, ``wrong_marker_proportion``,
        ``runs``) found in ``config``.
    """
    # Only the output root is substituted here; the doubled braces come out of
    # ``format`` as single-brace placeholders for ``expand`` to fill in.
    target_pattern = (
        '{outdir}/evaluate/celltypes/{{de_prob}}_{{down_prob}}_{{seed}}/{{ss_method}}_{{max_genes}}_{{wm_prop}}_{{run}}.tsv'
    ).format(outdir=config['outdir'])

    sim_variables = config['simulation_settings']['variables']
    assign_settings = config['assign_celltype_settings']

    return expand(
        target_pattern,
        de_prob=sim_variables['de_prob'],
        down_prob=sim_variables['down_prob'],
        seed=sim_variables['seed'],
        max_genes=assign_settings['max_genes'],
        ss_method=assign_settings['methods'],
        wm_prop=assign_settings['wrong_marker_proportion'],
        run=assign_settings['runs'],
    )

# Swept simulation variables; each key supplies the values expanded below.
_all_sim_variables = config['simulation_settings']['variables']

# all: aggregate target requesting every simulated SCE file plus every
# cell-type evaluation table listed by get_celltype_outputs.
rule all:
    input:
        expand(
            '{outdir}/simulated_sce/sce_{{de_prob}}_{{down_prob}}_{{seed}}.rds'.format(
                outdir=config['workdir']
            ),
            de_prob=_all_sim_variables['de_prob'],
            down_prob=_all_sim_variables['down_prob'],
            seed=_all_sim_variables['seed']
        ),
        # Input function: evaluated by Snakemake rather than called here.
        get_celltype_outputs,


# Fixed (non-swept) simulation settings, looked up once for the rule below.
_create_sce_constants = config['simulation_settings']['constants']

# create_sce: runs R/create_sce.R once per (de_prob, down_prob, seed) wildcard
# combination and writes the resulting .rds file under config['workdir'].
# Fixed settings come from the config; de_prob, down_prob and the random seed
# come from the wildcards. The command's output is redirected to the rule's
# log file via `>&`.
rule create_sce:
    input:
        base_params=_create_sce_constants['base_params'],
    output:
        sce='{outdir}/simulated_sce/sce_{{de_prob}}_{{down_prob}}_{{seed}}.rds'.format(
            outdir=config['workdir']
        ),
    params:
        name='create-sce-{seed}-{de_prob}-{down_prob}',
        sim_model='splat',
        # Values taken from the wildcards of the requested output file.
        random_seed='{seed}',
        de_prob='{de_prob}',
        down_prob='{down_prob}',
        # Values shared by every simulation.
        num_cells=_create_sce_constants['num_cells'],
        num_groups=_create_sce_constants['num_groups'],
        num_batches=_create_sce_constants['num_batches'],
        group_probs=_create_sce_constants['group_probs'],
        batch_probs=_create_sce_constants['batch_probs'],
        de_facloc=_create_sce_constants['de_facloc'],
        de_facscale=_create_sce_constants['de_facscale'],
        de_nu=_create_sce_constants['de_nu'],
        de_min=_create_sce_constants['de_min'],
        de_max=_create_sce_constants['de_max'],
        batch_facloc=_create_sce_constants['batch_facloc'],
        batch_facscale=_create_sce_constants['batch_facscale'],
    log:
        '{logdir}/logs/create_sce/{{de_prob}}_{{down_prob}}/{{seed}}.log'.format(
            logdir=config['logdir']
        ),
    benchmark:
        '{logdir}/benchmarks/create_sce/{{de_prob}}_{{down_prob}}/{{seed}}.txt'.format(
            logdir=config['logdir']
        ),
    shell:
        'Rscript R/create_sce.R '
        '--params {input.base_params} '
        '--outfname {output.sce} '
        '--seed {params.random_seed} '
        '--sim_model {params.sim_model} '
        '--num_cells {params.num_cells} '
        '--num_groups {params.num_groups} '
        '--num_batches {params.num_batches} '
        '--group_probs {params.group_probs} '
        '--batch_probs {params.batch_probs} '
        '--de_facloc {params.de_facloc} '
        '--de_facscale {params.de_facscale} '
        '--de_nu {params.de_nu} '
        '--de_min {params.de_min} '
        '--de_max {params.de_max} '
        '--batch_facloc {params.batch_facloc} '
        '--batch_facscale {params.batch_facscale} '
        '--de_prob {params.de_prob} '
        '--down_prob {params.down_prob} '
        '>& {log}'

# Settings applied to every cell-type assignment job, looked up once.
_assign_celltypes_settings = config['assign_celltype_settings']

# assign_celltypes_sce: runs R/assign_celltypes.R on one simulated SCE file for
# one (ss_method, max_genes, wm_prop, run) wildcard combination. It declares a
# cluster table (.tsv) and a fit object (.rds) under config['outdir'] as
# outputs; the command's output is redirected to the rule's log via `>&`.
rule assign_celltypes_sce:
    input:
        sce='{outdir}/simulated_sce/sce_{{de_prob}}_{{down_prob}}_{{seed}}.rds'.format(
            outdir=config['workdir']
        ),
    output:
        clust_results='{outdir}/assign_celltypes_sce/clusters/{{de_prob}}_{{down_prob}}_{{seed}}/{{ss_method}}_{{max_genes}}_{{wm_prop}}_{{run}}.tsv'.format(
            outdir=config['outdir']
        ),
        fit_results='{outdir}/assign_celltypes_sce/fits/{{de_prob}}_{{down_prob}}_{{seed}}/{{ss_method}}_{{max_genes}}_{{wm_prop}}_{{run}}.rds'.format(
            outdir=config['outdir']
        ),
    params:
        name='assign-celltypes-sce-{seed}-{de_prob}-{down_prob}-{ss_method}-{max_genes}-{wm_prop}-{run}',
        # Values taken from the wildcards of the requested output files.
        ss_method='{ss_method}',
        max_genes='{max_genes}',
        wm_prop='{wm_prop}',
        # Values shared by every assignment job.
        conda_env=_assign_celltypes_settings['conda_env'],
        fc_percentile=_assign_celltypes_settings['fc_percentile'],
        expr_percentile=_assign_celltypes_settings['expr_percentile'],
        frac_genes=_assign_celltypes_settings['frac_genes'],
        # Per-run subsets are looked up through the `run` wildcard, so they are
        # supplied as functions of the wildcards.
        data_subset=lambda wildcards: config['assign_celltype_settings']['runs'][wildcards.run]['data_celltypes'],
        marker_subset=lambda wildcards: config['assign_celltype_settings']['runs'][wildcards.run]['marker_celltypes'],
    log:
        '{logdir}/logs/assign_celltypes_sce/{{de_prob}}_{{down_prob}}_{{seed}}/{{ss_method}}_{{max_genes}}_{{wm_prop}}_{{run}}.log'.format(
            logdir=config['logdir']
        ),
    benchmark:
        '{logdir}/benchmarks/assign_celltypes_sce/{{de_prob}}_{{down_prob}}_{{seed}}/{{ss_method}}_{{max_genes}}_{{wm_prop}}_{{run}}.txt'.format(
            logdir=config['logdir']
        ),
    shell:
        'Rscript R/assign_celltypes.R '
        '--sce {input.sce} '
        '--out_clusters {output.clust_results} '
        '--out_fit {output.fit_results} '
        '--conda_env {params.conda_env} '
        '--fc_percentile {params.fc_percentile} '
        '--expr_percentile {params.expr_percentile} '
        '--frac_genes {params.frac_genes} '
        '--method {params.ss_method} '
        '--max_genes {params.max_genes} '
        '--wrong_prop {params.wm_prop} '
        '--data_subset {params.data_subset} '
        '--marker_subset {params.marker_subset} '
        '>& {log}'

# Assignment settings; only the conda environment name is needed below.
_evaluate_celltypes_settings = config['assign_celltype_settings']

# evaluate_celltypes: runs R/evaluate_cluster.R on a simulated SCE file and the
# matching cluster table from assign_celltypes_sce, declaring one evaluation
# table (.tsv) under config['outdir'] as output. The command's output is
# redirected to the rule's log via `>&`.
rule evaluate_celltypes:
    input:
        sce='{outdir}/simulated_sce/sce_{{de_prob}}_{{down_prob}}_{{seed}}.rds'.format(
            outdir=config['workdir']
        ),
        clust_results='{outdir}/assign_celltypes_sce/clusters/{{de_prob}}_{{down_prob}}_{{seed}}/{{ss_method}}_{{max_genes}}_{{wm_prop}}_{{run}}.tsv'.format(
            outdir=config['outdir']
        ),
    output:
        evaluation_results='{outdir}/evaluate/celltypes/{{de_prob}}_{{down_prob}}_{{seed}}/{{ss_method}}_{{max_genes}}_{{wm_prop}}_{{run}}.tsv'.format(
            outdir=config['outdir']
        ),
    params:
        name='evaluate-celltypes-{seed}-{de_prob}-{down_prob}-{ss_method}-{max_genes}-{wm_prop}-{run}',
        method='{ss_method}',
        conda_env=_evaluate_celltypes_settings['conda_env'],
        # Marker cell types are looked up through the `run` wildcard, so they
        # are supplied as a function of the wildcards.
        marker_subset=lambda wildcards: config['assign_celltype_settings']['runs'][wildcards.run]['marker_celltypes'],
    log:
        '{logdir}/logs/evaluate_celltypes/{{de_prob}}_{{down_prob}}_{{seed}}/{{ss_method}}_{{max_genes}}_{{wm_prop}}_{{run}}.log'.format(
            logdir=config['logdir']
        ),
    benchmark:
        '{logdir}/benchmarks/evaluate_celltypes/{{de_prob}}_{{down_prob}}_{{seed}}/{{ss_method}}_{{max_genes}}_{{wm_prop}}_{{run}}.txt'.format(
            logdir=config['logdir']
        ),
    shell:
        'Rscript R/evaluate_cluster.R '
        '--sce {input.sce} '
        '--cluster_file {input.clust_results} '
        '--outfname {output.evaluation_results} '
        '--marker_types {params.marker_subset} '
        '--conda_env {params.conda_env} '
        '--method {params.method} '
        '>& {log}'


    
    